#ifndef __EQUATION__
#define __EQUATION__

#include "Mesh.h"
#include "Element.h"
#include "Point.h"
#include <iostream>
#include <string>
#include <cmath>
#include <fstream>

/// Helper macros spelling out the template parameter lists (…_TEMPLATE) and
/// argument lists (…_TEMPLATE_INSIDE) of the classes below, so the
/// out-of-class member definitions stay short. All six are #undef'd at the
/// end of this header.
/// FEM_TEMPLATE / FEM_TEMPLATE_INSIDE: for the FEM base class.
#define FEM_TEMPLATE template <typename Solution_Type, typename Mesh_Type, typename Element_Type, typename Matrix_Type, int DIM>
#define FEM_TEMPLATE_INSIDE Solution_Type, Mesh_Type, Element_Type, Matrix_Type, DIM
/// IRREGULAR_MESH_TEMPLATE(_INSIDE): for FEM2D_IrregularMesh (no DIM parameter;
/// the mesh classes derive from FEM<..., 2>).
#define IRREGULAR_MESH_TEMPLATE template <typename Solution_Type, typename Mesh_Type, typename Element_Type, typename Matrix_Type>
#define IRREGULAR_MESH_TEMPLATE_INSIDE Solution_Type, Mesh_Type, Element_Type, Matrix_Type
/// FEM2D_RegularMesh uses exactly the same parameter list as FEM2D_IrregularMesh.
#define REGULAR_MESH_TEMPLATE IRREGULAR_MESH_TEMPLATE
#define REGULAR_MESH_TEMPLATE_INSIDE IRREGULAR_MESH_TEMPLATE_INSIDE

#include <Eigen/Sparse>
#include <Eigen/IterativeLinearSolvers>
/// Global sparse system matrix.
using SpMat = Eigen::SparseMatrix<double>;
/// (row, column, value) entry; element kernels emit these and SpMat is built from them in one pass.
using Tri = Eigen::Triplet<double>;
/// Dense right-hand-side vector.
using RHS = Eigen::VectorXd;
/// Dense vector of nodal solution values.
using Solution = Eigen::VectorXd;

/////////////////////////////////////////////////////////////
/// FEM(2D)-|
///                    |-Equations                                                 |-Real Problem=Equation+Condition+Mesh+Solver
///                    |--Poisson
///                    |--...(TODO)
///                    |
///                    |-Boundary Conditions
///                    |--Dirichlet Condition
///                    |--Neumann Condition(TODO)
///                    |--Third kind Boundary Condition(TODO)
///                    |--Mixed(TODO)
///                    |--...
///                    |
///                    |-Meshes
///                    |--Triangle Mesh
///                    |---Irregular Mesh
///                    |----Regular Mesh(regard as a special kind of Irregular Mesh)
///                    |--Rectangle Mesh(TODO)
///                    |--...
///                    |
///                    |-Solvers
///                    |--Iterative solver(CG)
///                    |--Multi-Grid(TODO)
///                    |--...
//////////////////////////////////////////////////////////////

/// Forward declarations of every class in this header. A concrete problem is
/// assembled by inheriting one class from each family below; they all share
/// the abstract FEM base (virtual inheritance, see the class definitions).

/// Abstract base class.
template <typename Solution_Type, typename Mesh_Type, typename Element_Type, typename Matrix_Type, int DIM>
class FEM;

/// Equations (provide Assemble_Core)
template <typename Mesh_Type, typename Element_Type, int DIM>
class FEM2D_Possion;

/// Boundary conditions (provide Boundary_Condition)
template <typename Mesh_Type, typename Element_Type, int DIM>
class FEM2D_Dirichlet_Condition;

/// Solvers (provide Solver); DIM is fixed to 2
template <typename Mesh_Type, typename Element_Type>
class FEM2D_CG_Solver;

/// Meshes (provide Initial / Output); no DIM parameter: they derive from FEM<..., 2>
template <typename Solution_Type, typename Mesh_Type, typename Element_Type, typename Matrix_Type>  // DIM fixed to 2
class FEM2D_IrregularMesh;
template <typename Solution_Type, typename Mesh_Type, typename Element_Type, typename Matrix_Type>
class FEM2D_RegularMesh;

/// Abstract base of every FEM problem in this header.
///
/// A concrete problem is built by (virtually) inheriting one class from each
/// family: an equation (implements Assemble_Core), a boundary condition
/// (Boundary_Condition), a solver (Solver) and a mesh (Initial / Output).
///
/// @tparam Solution_Type  vector type of the solution and right-hand side
///                        (needs setZero(n) and operator[]).
/// @tparam Mesh_Type      mesh (needs get_n_dofs, get_n_grids, get_grid, get_dof).
/// @tparam Element_Type   reference element (quadrature rule and basis functions).
/// @tparam Matrix_Type    global matrix (constructible as (n, n); needs
///                        setFromTriplets and makeCompressed).
/// @tparam DIM            spatial dimension of the points.
template <typename Solution_Type, typename Mesh_Type, typename Element_Type, typename Matrix_Type, int DIM>
class FEM
{
protected:
    Solution_Type solution; ///< nodal values; written by the Solver() implementations
    Mesh_Type mesh;         ///< computational mesh
    Element_Type ref_ele;   ///< reference element, re-mapped onto each grid cell inside the element loops

public:
    /// Polymorphic base: the virtual destructor makes deleting a concrete
    /// problem through a FEM pointer well defined.
    virtual ~FEM() {}

    /// Builds the mesh (implemented by the mesh classes).
    virtual void Initial() = 0;
    /// Solves K * solution = Rhs with tolerance tol. Pure virtual because the
    /// solver depends on the concrete Matrix_Type.
    virtual void Solver(Matrix_Type &K, Solution_Type &Rhs, double tol) = 0;
    /// Writes the solution to plot files (implemented by the mesh classes).
    virtual void Output() = 0;
    /// Builds the global matrix K and adds the load vector into Rhs (Rhs is not
    /// zeroed here). f is the non-homogeneous (source) term of the equation.
    virtual void Assemble(Matrix_Type &K, Solution_Type &Rhs, double f(const Point<DIM, double> &), double time_interval);
    /// Element kernel used by Assemble() for grid cell _ei: writes matrix
    /// triplets through `it` (advancing it) and adds load contributions to Rhs.
    /// This is a deliberately simple, hand-rolled assembly scheme.
    virtual void Assemble_Core(size_t _ei, Solution_Type &Rhs, double f(const Point<DIM, double> &), std::vector<Tri>::iterator &it, double time_interval) = 0;
    /// Imposes the boundary conditions on the assembled system in place; fai
    /// gives the boundary values. NOTE(review): an older design note said mixed
    /// conditions would pass an array of functions, but the parameter is a
    /// single function pointer.
    virtual void Boundary_Condition(Matrix_Type &K, Solution_Type &Rhs, double (*fai)(const Point<DIM, double> &)) = 0;
    /// One assemble / boundary-condition / solve cycle with quadrature accuracy _acc.
    virtual void Take_one_step(int _acc, double f(const Point<DIM, double> &), double (*fai)(const Point<DIM, double> &), double time_interval);
    /// Calls Take_one_step time_steps times.
    virtual void Take_steps(int _acc, double f(const Point<DIM, double> &), double (*fai)(const Point<DIM, double> &), double time_interval, int time_steps);
    /// L2 norm of (u - solution), evaluated with quadrature accuracy _acc.
    virtual double L2_error(int _acc, double u(const Point<DIM, double> &));
    Mesh_Type &get_mesh() { return mesh; }
    Solution_Type &get_solution() { return solution; }
    Element_Type &get_ref_ele() { return ref_ele; }
    /// Copies _mesh into this problem (const reference: the argument is only read).
    void set_mesh(const Mesh_Type &_mesh) { mesh = _mesh; }
};

FEM_TEMPLATE
void FEM<FEM_TEMPLATE_INSIDE>::Take_one_step(int _acc, double f(const Point<DIM, double> &), double (*fai)(const Point<DIM, double> &), double time_interval)
{
    // Load a quadrature rule of at least the requested algebraic accuracy when
    // the current one is too coarse.
    const bool need_finer_quadrature = ref_ele.get_algebraic_accuracy() < _acc;
    if (need_finer_quadrature)
        ref_ele.read_quad_info("data/triangle.tmp_geo", _acc);

    const size_t n_global = mesh.get_n_dofs();

    // A fresh n x n system for this step; the right-hand side starts at zero.
    Matrix_Type stiffness(n_global, n_global);
    Solution_Type load;
    load.setZero(n_global);

    Assemble(stiffness, load, f, time_interval);
    Boundary_Condition(stiffness, load, fai);

    const double solver_tolerance = 1e-15;
    Solver(stiffness, load, solver_tolerance);
}

FEM_TEMPLATE
void FEM<FEM_TEMPLATE_INSIDE>::Take_steps(int _acc, double f(const Point<DIM, double> &), double (*fai)(const Point<DIM, double> &), double time_interval, int time_steps)
{
    // Run the single-step driver time_steps times; a non-positive count does nothing.
    for (int remaining = time_steps; remaining > 0; --remaining)
    {
        Take_one_step(_acc, f, fai, time_interval);
    }
}

FEM_TEMPLATE
void FEM<FEM_TEMPLATE_INSIDE>::Assemble(Matrix_Type &K, Solution_Type &Rhs, double f(const Point<DIM, double> &), double time_interval)
{
    const size_t n_quad_pnts = ref_ele.get_n_quad_points();
    const size_t n_ele = mesh.get_n_grids();
    const size_t n_dof = ref_ele.get_n_dofs();

    // Capacity: one triplet per (i, j) local-dof pair per quadrature point per
    // element. Assemble_Core must not write more than
    // n_dof * n_dof * n_quad_pnts triplets per element.
    std::vector<Tri> TriList(n_ele * n_dof * n_dof * n_quad_pnts);
    std::vector<Tri>::iterator it = TriList.begin();
    std::cout << "its size:" << TriList.size() << std::endl;

    for (size_t ei = 0; ei < n_ele; ei++)
        Assemble_Core(ei, Rhs, f, it, time_interval);

    // Use only the triplets Assemble_Core actually wrote. Any unused,
    // default-constructed tail of TriList is no longer passed to the matrix.
    K.setFromTriplets(TriList.begin(), it);
    K.makeCompressed();
}

FEM_TEMPLATE
double FEM<FEM_TEMPLATE_INSIDE>::L2_error(int _acc, double u(const Point<DIM, double> &))
{
    const size_t n_ele = mesh.get_n_grids();
    const size_t n_dof = ref_ele.get_n_dofs();
    // Load a quadrature rule of at least the requested accuracy if needed.
    if (ref_ele.get_algebraic_accuracy() < _acc)
        ref_ele.read_quad_info("data/triangle.tmp_geo", _acc);
    /// Get the quadrature information from the reference element.
    const std::valarray<Point<DIM, double>> &quad_pnts = ref_ele.get_quad_points();
    const std::valarray<double> &weights = ref_ele.get_quad_weights();
    const size_t n_quad_pnts = ref_ele.get_n_quad_points();
    const double volume = ref_ele.get_volume();

    // Accumulates the integral of (u - u_h)^2 over all grid cells.
    double error = 0.0;
    for (size_t ei = 0; ei < n_ele; ei++)
    {
        /// Map the reference element onto grid cell ei by giving it the
        /// coordinates of the cell's global dofs.
        const Geometry &the_ele = mesh.get_grid(ei);
        const std::valarray<size_t> &dof = the_ele.get_dof();
        for (size_t k = 0; k < n_dof; k++)
            ref_ele.set_global_dof(k, mesh.get_dof(dof[k]));

        for (size_t l = 0; l < n_quad_pnts; l++)
        {
            const double detJ = ref_ele.global2local_Jacobi_det(quad_pnts[l]);
            // Finite-element solution u_h at the quadrature point.
            double u_sol = 0;
            for (size_t i = 0; i < n_dof; i++)
            {
                double phi_i = ref_ele.basis_function(i, quad_pnts[l]);
                u_sol += phi_i * solution[dof[i]];
            }
            // Evaluate the exact solution once per quadrature point, not twice.
            const double diff = u(ref_ele.local2global(quad_pnts[l])) - u_sol;
            error += diff * diff * weights[l] * detJ * volume;
        }
    }
    error = std::sqrt(error);
    std::cout << "L2 error: " << error << std::endl
              << std::endl;
    return error;
}

/// Poisson equation -Laplace(u) = f, weak form  int grad(u).grad(v) = int f v.
///
/// Supplies Assemble_Core(); the boundary condition, solver and mesh come from
/// the other classes a concrete problem inherits. The problem is stationary, so
/// the time-stepping entry points are wrapped with a unit interval and one step.
template <typename Mesh_Type, typename Element_Type, int DIM>
class FEM2D_Possion : virtual public FEM<Solution, Mesh_Type, Element_Type, SpMat, DIM>
{
    typedef FEM<Solution, Mesh_Type, Element_Type, SpMat, DIM> Fem_Base;

public:
    /// One assemble / boundary / solve cycle with time_interval fixed to 1.
    void Take_one_step(int _acc, double f(const Point<DIM, double> &), double (*fai)(const Point<DIM, double> &))
    {
        Fem_Base::Take_one_step(_acc, f, fai, 1.);
    }

    /// Exactly one step with time_interval 1.
    void Take_steps(int _acc, double f(const Point<DIM, double> &), double (*fai)(const Point<DIM, double> &))
    {
        Fem_Base::Take_steps(_acc, f, fai, 1., 1);
    }

    /// Emits the stiffness triplets and load contributions of grid cell _ei.
    virtual void Assemble_Core(size_t _ei, Solution &Rhs, double f(const Point<DIM, double> &), std::vector<Tri>::iterator &it, double time_interval) override;
};

template <typename Mesh_Type, typename Element_Type, int DIM>
void FEM2D_Possion<Mesh_Type, Element_Type, DIM>::Assemble_Core(size_t _ei, Solution &Rhs, double f(const Point<DIM, double> &), std::vector<Tri>::iterator &it, double time_interval)
{
    // Stationary problem: the time interval does not enter the element matrices.
    (void)time_interval;

    // Reference-element data is read on every call instead of being cached in
    // function-local statics. A static is shared by every object of this
    // instantiation and would keep whatever values the first caller's ref_ele had.
    const std::valarray<Point<DIM, double>> &quad_pnts = this->ref_ele.get_quad_points();
    const std::valarray<double> &weights = this->ref_ele.get_quad_weights();
    const size_t n_quad_pnts = this->ref_ele.get_n_quad_points();
    const size_t n_dof = this->ref_ele.get_n_dofs();
    const double volume = this->ref_ele.get_volume();

    // Map the reference element onto grid cell _ei.
    const Geometry &the_ele = this->mesh.get_grid(_ei);
    const std::valarray<size_t> &dof = the_ele.get_dof();
    for (size_t k = 0; k < n_dof; k++)
        this->ref_ele.set_global_dof(k, this->mesh.get_dof(dof[k]));

    // Physical basis gradients at the current quadrature point. They are computed
    // once per point rather than once per (i, j) pair.
    std::vector<std::valarray<double>> grads;
    grads.reserve(n_dof);
    for (size_t l = 0; l < n_quad_pnts; l++)
    {
        const double detJ = this->ref_ele.global2local_Jacobi_det(quad_pnts[l]);
        Matrix<double> invJ = this->ref_ele.local2global_Jacobi(quad_pnts[l]);
        grads.clear();
        for (size_t i = 0; i < n_dof; i++)
            grads.push_back(invJ * this->ref_ele.basis_gradient(i, quad_pnts[l]));
        // Source term at the physical quadrature point; it does not depend on i.
        const double f_val = f(this->ref_ele.local2global(quad_pnts[l]));

        for (size_t i = 0; i < n_dof; i++)
        {
            const std::valarray<double> &grad_i = grads[i];
            const double phi_i = this->ref_ele.basis_function(i, quad_pnts[l]);
            for (size_t j = 0; j < n_dof; j++)
            {
                const std::valarray<double> &grad_j = grads[j];
                // Stiffness: grad(phi_i) . grad(phi_j) * w_l * detJ * volume (2D dot product).
                const double cont = (grad_i[0] * grad_j[0] + grad_i[1] * grad_j[1]) * weights[l] * detJ * volume;
                *it = Tri(dof[i], dof[j], cont);
                ++it;
            }
            // Load: f(x_l) * phi_i * w_l * detJ * volume.
            Rhs[dof[i]] += f_val * phi_i * weights[l] * detJ * volume;
        }
    }
}

/// Dirichlet condition u = fai(x), applied to every mesh dof whose boundary mark is 1.
template <typename Mesh_Type, typename Element_Type, int DIM>
class FEM2D_Dirichlet_Condition : virtual public FEM<Solution, Mesh_Type, Element_Type, SpMat, DIM>
{
public:
    /// Modifies the assembled system K * u = Rhs in place to impose the boundary values.
    virtual void Boundary_Condition(SpMat &K, Solution &Rhs, double (*fai)(const Point<DIM, double> &)) override;
};

template <typename Mesh_Type, typename Element_Type, int DIM>
void FEM2D_Dirichlet_Condition<Mesh_Type, Element_Type, DIM>::Boundary_Condition(SpMat &K, Solution &Rhs, double (*fai)(const Point<DIM, double> &))
{
    const size_t n_dofs = this->mesh.get_n_dofs();
    std::vector<size_t> coupled_rows; // off-diagonal rows of column i; reused across dofs
    for (size_t i = 0; i < n_dofs; i++)
    {
        const Point<DIM, double> &global_dof = this->mesh.get_dof(i);
        // Mark 1 = Dirichlet dof. Mark 0 (interior) and any other mark are left untouched.
        if (global_dof.get_boundary_mark() != 1)
            continue;

        // fai is a single function pointer. The former fai[0](...) applied
        // pointer arithmetic to a function pointer, which is ill-formed C++.
        const double boundary_value = fai(global_dof);

        // Keep K(i, i) and set Rhs[i] so that row i reads K(i,i) * u_i = K(i,i) * g.
        Rhs[i] = boundary_value * K.coeffRef(i, i);

        // Symmetric elimination: move column i (times g) to the right-hand side
        // and clear it. it.value() is K(row, i), the coefficient of u_i in row
        // `row`. The old code read K(i, row), which matches only for symmetric K.
        coupled_rows.clear();
        for (SpMat::InnerIterator it(K, i); it; ++it)
        {
            const size_t row = it.row();
            if (row == i)
                continue;
            Rhs[row] -= it.value() * boundary_value;
            it.valueRef() = 0.0;
            coupled_rows.push_back(row);
        }
        // Clear row i only after the iteration: coeffRef() on another column may
        // insert and reallocate storage, invalidating an active InnerIterator.
        for (size_t r = 0; r < coupled_rows.size(); r++)
            K.coeffRef(i, coupled_rows[r]) = 0.0;
    }
}

/// Solver family: Eigen conjugate-gradient solve of the sparse system (2D problems only).
template <typename Mesh_Type, typename Element_Type>
class FEM2D_CG_Solver : virtual public FEM<Solution, Mesh_Type, Element_Type, SpMat, 2>
{
public:
    /// Solves K * solution = Rhs with tolerance tol and stores the result in `solution`.
    virtual void Solver(SpMat &K, Solution &Rhs, double tol) override;
};
template <typename Mesh_Type, typename Element_Type>
void FEM2D_CG_Solver<Mesh_Type, Element_Type>::Solver(SpMat &K, Solution &Rhs, double tol)
{
    Eigen::ConjugateGradient<SpMat> Solver_sparse;
    Solver_sparse.setTolerance(tol);
    Solver_sparse.compute(K);
    this->solution = Solver_sparse.solve(Rhs);
    // Report non-convergence (e.g. a tolerance too tight for the default
    // iteration limit) instead of silently keeping an inaccurate solution.
    if (Solver_sparse.info() != Eigen::Success)
        std::cerr << "CG solver did not converge: " << Solver_sparse.iterations()
                  << " iterations, estimated error " << Solver_sparse.error() << std::endl;
}

/// Mesh family: unstructured 2D triangle mesh read from an Easymesh data set.
/// For plotting, every triangle is recursively split `rec` times before the
/// solution is sampled.
IRREGULAR_MESH_TEMPLATE
class FEM2D_IrregularMesh : virtual public FEM<Solution_Type, Mesh_Type, Element_Type, Matrix_Type, 2>
{
public:
    /// Reads the default data set "data/D".
    virtual void Initial() override;
    /// Reads the Easymesh data set `filename`.
    virtual void Initial(const std::string &filename)
    {
        Initial_Easymesh(filename);
    }
    /// Delegates to Mesh_Type::readData(filename).
    void Initial_Easymesh(const std::string &filename);

    /// Plots to base name "u1" with rec = 0 (dispatched through the virtual 2-argument overload).
    virtual void Output() override;
    /// Plots to base name `filename` with rec = 0.
    virtual void Output(const std::string &filename);
    /// Writes both the Matlab (.m) and the OpenDx (.dx) plot.
    virtual void Output(const std::string &filename, unsigned int rec);
    void Irregularmesh2D_Matlab_plot(const std::string &filename, unsigned int rec);
    void Irregularmesh2D_OpenDx_plot(const std::string &filename, unsigned int rec);
    /// Recursively splits one triangle of a plotting grid into 4^rec sub-triangles.
    void refinement(std::valarray<Point<2, double>> &_nodes,
                    int local_n1_index, int local_n2_index, int local_n3_index,
                    size_t new_index[3], int rec, std::valarray<std::valarray<size_t>> &connections);
};

IRREGULAR_MESH_TEMPLATE
void FEM2D_IrregularMesh<IRREGULAR_MESH_TEMPLATE_INSIDE>::Initial()
{
    // Default mesh: the Easymesh data set with base name "data/D".
    const std::string default_mesh = "data/D";
    this->Initial_Easymesh(default_mesh);
}

IRREGULAR_MESH_TEMPLATE
void FEM2D_IrregularMesh<IRREGULAR_MESH_TEMPLATE_INSIDE>::Initial_Easymesh(const std::string &filename)
{
    // Parsing is handled entirely by the mesh object.
    Mesh_Type &target = this->mesh;
    target.readData(filename);
}

IRREGULAR_MESH_TEMPLATE
void FEM2D_IrregularMesh<IRREGULAR_MESH_TEMPLATE_INSIDE>::Output()
{
    // Default base name "u1", rec = 0; the 2-argument overload is virtual.
    const std::string default_name = "u1";
    this->Output(default_name, 0u);
}

IRREGULAR_MESH_TEMPLATE
void FEM2D_IrregularMesh<IRREGULAR_MESH_TEMPLATE_INSIDE>::Output(const std::string &filename)
{
    // Same as the 2-argument overload with rec = 0 (virtual dispatch).
    this->Output(filename, 0u);
}

IRREGULAR_MESH_TEMPLATE
void FEM2D_IrregularMesh<IRREGULAR_MESH_TEMPLATE_INSIDE>::Output(const std::string &filename, unsigned int rec)
{
    // Matlab script first (<filename>.m), then the OpenDx file (<filename>.dx).
    this->Irregularmesh2D_Matlab_plot(filename, rec);
    this->Irregularmesh2D_OpenDx_plot(filename, rec);
}

/// Recursively splits the triangle (local_n1_index, local_n2_index,
/// local_n3_index) of _nodes into 4 sub-triangles, rec levels deep.
///
/// Each split stores the three edge mid-points in the next free local slots of
/// _nodes (counter new_index[0]) and gives them consecutive global indices
/// (counter new_index[1]). Mid-points are not looked up, so an edge shared by
/// two sub-triangles at a deeper level gets its mid-point created twice. At
/// rec == 0 the triangle's three global indices are written to
/// connections[new_index[2]] and that counter is advanced.
///
/// @param _nodes       plotting nodes of one grid; must have room for every new node.
/// @param new_index    in/out counters {next local slot, next global index, next triangle}.
/// @param rec          remaining refinement levels.
/// @param connections  output triangles; every connections[k] must already hold 3 entries.
IRREGULAR_MESH_TEMPLATE
void FEM2D_IrregularMesh<IRREGULAR_MESH_TEMPLATE_INSIDE>::refinement(std::valarray<Point<2, double>> &_nodes,
                                                                     int local_n1_index, int local_n2_index, int local_n3_index,
                                                                     size_t new_index[3], int rec, std::valarray<std::valarray<size_t>> &connections)
{
    // Degenerate input is only reported; processing continues.
    if (local_n1_index == local_n2_index || local_n1_index == local_n3_index || local_n2_index == local_n3_index)
        std::cerr << "two equal node index" << std::endl;
    if (rec > 0)
    {
        // Reserve three local slots (size_t counter narrowed to int).
        int new_node1_index = new_index[0]++;
        int new_node2_index = new_index[0]++;
        int new_node3_index = new_index[0]++;

        // new_nodeK is the mid-point of the edge opposite vertex nK; the order
        // of these statements fixes the global indices handed out.
        _nodes[new_node1_index] = mid_point<2, double>(_nodes[local_n2_index], _nodes[local_n3_index]);
        _nodes[new_node1_index].set_index(new_index[1]++);
        _nodes[new_node2_index] = mid_point<2, double>(_nodes[local_n1_index], _nodes[local_n3_index]);
        _nodes[new_node2_index].set_index(new_index[1]++);
        _nodes[new_node3_index] = mid_point<2, double>(_nodes[local_n1_index], _nodes[local_n2_index]);
        _nodes[new_node3_index].set_index(new_index[1]++);

        // Three corner sub-triangles, then the central one.
        refinement(_nodes, local_n1_index, new_node2_index, new_node3_index, new_index, rec - 1, connections);
        refinement(_nodes, local_n2_index, new_node1_index, new_node3_index, new_index, rec - 1, connections);
        refinement(_nodes, local_n3_index, new_node1_index, new_node2_index, new_index, rec - 1, connections);
        refinement(_nodes, new_node1_index, new_node2_index, new_node3_index, new_index, rec - 1, connections);
    }
    else
    {
        // Leaf: record the triangle by the global indices of its vertices.
        connections[new_index[2]][0] = _nodes[local_n1_index].get_index();
        connections[new_index[2]][1] = _nodes[local_n2_index].get_index();
        connections[new_index[2]][2] = _nodes[local_n3_index].get_index();
        new_index[2]++;
    }
}

IRREGULAR_MESH_TEMPLATE
void FEM2D_IrregularMesh<IRREGULAR_MESH_TEMPLATE_INSIDE>::Irregularmesh2D_Matlab_plot(const std::string &filename, unsigned int rec)
{
    // rec is unsigned, so a negative argument arrives as a huge value that would
    // make 4^rec overflow. Detect it through its signed reinterpretation.
    if (static_cast<int>(rec) < 0)
    {
        std::cerr << "reccurance cannot be negative!!" << std::endl;
        rec = 0;
    }
    const size_t n_ele = this->mesh.get_n_grids();
    const size_t n_dof = this->ref_ele.get_n_dofs();
    // rec refinement levels split every grid into 4^rec triangles. refinement()
    // keeps the 3 vertices and creates 3 new nodes per split, which gives
    // 3 + 3*(1 + 4 + ... + 4^(rec-1)) = 2 + 4^rec nodes per grid.
    // Integer arithmetic avoids the floating-point round trip of std::pow.
    size_t ele_per_grid = 1;
    for (unsigned int r = 0; r < rec; ++r)
        ele_per_grid *= 4;
    const size_t node_per_grid = 2 + ele_per_grid;

    std::valarray<std::valarray<Point<2, double>>> nodes;
    std::valarray<std::valarray<std::valarray<size_t>>> connections;
    std::valarray<std::valarray<double>> values;
    nodes.resize(n_ele);
    connections.resize(n_ele);
    values.resize(n_ele);

    size_t node_index[3] = {0, 0, 0}; //local_index,global_index,local_element_index
    for (size_t ei = 0; ei < n_ele; ei++)
    {
        nodes[ei].resize(node_per_grid);
        // Each triangle needs 3 slots because refinement() writes
        // connections[k][0..2]. Previously the inner arrays were empty, so those
        // writes were out of bounds.
        connections[ei].resize(ele_per_grid, std::valarray<size_t>(3));
        values[ei].resize(node_per_grid);

        const Geometry &the_ele = this->mesh.get_grid(ei);
        const std::valarray<size_t> &dof = the_ele.get_dof();
        node_index[0] = 3; // local slots 0..2 hold the grid's own vertices
        node_index[2] = 0;
        for (size_t k = 0; k < n_dof; k++)
            this->ref_ele.set_global_dof(k, this->mesh.get_dof(dof[k]));
        for (size_t k = 0; k < 3; k++)
        {
            nodes[ei][k] = this->mesh.get_dof(dof[k]);
            nodes[ei][k].set_index(node_index[1]++);
        }
        refinement(nodes[ei], 0, 1, 2, node_index, rec, connections[ei]);
        // Interpolate the FE solution at every plotting node of this grid.
        for (size_t k = 0; k < node_per_grid; k++)
        {
            Point<2, double> local_pnt = this->ref_ele.global2local(nodes[ei][k]);
            double u_sol = 0;
            for (size_t i = 0; i < n_dof; i++)
            {
                double phi_i = this->ref_ele.basis_function(i, local_pnt);
                u_sol += phi_i * this->solution[dof[i]];
            }
            values[ei][k] = u_sol;
        }
    }

    /// Output the solution to file in Matlab format. Node global indices equal
    /// ei * node_per_grid + k, which is the row order of P.
    std::ofstream output((filename + ".m").c_str());
    output << "T=zeros(" << n_ele * ele_per_grid << ","
           << "3);" << std::endl;
    output << "P=zeros(" << n_ele * node_per_grid << ","
           << "3);" << std::endl;

    for (size_t ei = 0; ei < n_ele; ei++)
        for (size_t k = 0; k < node_per_grid; k++)
            output << "P(" << ei * node_per_grid + k + 1 << ",:)=[" << nodes[ei][k][0] << "\t" << nodes[ei][k][1] << "\t" << values[ei][k] << "];" << '\n';

    for (size_t ei = 0; ei < n_ele; ei++)
        for (size_t k = 0; k < ele_per_grid; k++)
            output << "T(" << ei * ele_per_grid + k + 1 << ",:)=["
                   << connections[ei][k][0] + 1 << "\t" << connections[ei][k][1] + 1 << "\t" << connections[ei][k][2] + 1 << "];" << '\n';

    output << "TR=triangulation(T,P(:,1),P(:,2),P(:,3));" << std::endl;
    output << "trisurf(TR);" << std::endl;
    output.close();
}

IRREGULAR_MESH_TEMPLATE
void FEM2D_IrregularMesh<IRREGULAR_MESH_TEMPLATE_INSIDE>::Irregularmesh2D_OpenDx_plot(const std::string &filename, unsigned int rec)
{
    // rec is unsigned, so the former `rec < 0` test could never fire. A negative
    // argument arrives as a huge value; detect it via its signed reinterpretation.
    if (static_cast<int>(rec) < 0)
    {
        std::cerr << "reccurance cannot be negative!!" << std::endl;
        rec = 0;
    }
    const size_t n_ele = this->mesh.get_n_grids();
    const size_t n_dof = this->ref_ele.get_n_dofs();
    // rec refinement levels split every grid into 4^rec triangles with
    // 3 + 3*(1 + 4 + ... + 4^(rec-1)) = 2 + 4^rec (unshared) nodes.
    // Integer arithmetic avoids the floating-point round trip of std::pow.
    size_t ele_per_grid = 1;
    for (unsigned int r = 0; r < rec; ++r)
        ele_per_grid *= 4;
    const size_t node_per_grid = 2 + ele_per_grid;

    std::valarray<std::valarray<Point<2, double>>> nodes;
    std::valarray<std::valarray<std::valarray<size_t>>> connections;
    std::valarray<std::valarray<double>> values;
    nodes.resize(n_ele);
    connections.resize(n_ele);
    values.resize(n_ele);

    size_t node_index[3] = {0, 0, 0}; //local_index,global_index,local_element_index
    for (size_t ei = 0; ei < n_ele; ei++)
    {
        nodes[ei].resize(node_per_grid);
        // Each triangle needs 3 slots because refinement() writes
        // connections[k][0..2]. Previously the inner arrays were empty, so those
        // writes were out of bounds.
        connections[ei].resize(ele_per_grid, std::valarray<size_t>(3));
        values[ei].resize(node_per_grid);

        const Geometry &the_ele = this->mesh.get_grid(ei);
        const std::valarray<size_t> &dof = the_ele.get_dof();
        node_index[0] = 3; // local slots 0..2 hold the grid's own vertices
        node_index[2] = 0;
        for (size_t k = 0; k < n_dof; k++)
            this->ref_ele.set_global_dof(k, this->mesh.get_dof(dof[k]));
        for (size_t k = 0; k < 3; k++)
        {
            nodes[ei][k] = this->mesh.get_dof(dof[k]);
            nodes[ei][k].set_index(node_index[1]++);
        }
        refinement(nodes[ei], 0, 1, 2, node_index, rec, connections[ei]);
        // Interpolate the FE solution at every plotting node of this grid.
        for (size_t k = 0; k < node_per_grid; k++)
        {
            Point<2, double> local_pnt = this->ref_ele.global2local(nodes[ei][k]);
            double u_sol = 0;
            for (size_t i = 0; i < n_dof; i++)
            {
                double phi_i = this->ref_ele.basis_function(i, local_pnt);
                u_sol += phi_i * this->solution[dof[i]];
            }
            values[ei][k] = u_sol;
        }
    }

    /// Output the solution to file in OpenDx format (0-based connections;
    /// a node's global index equals its position ei * node_per_grid + k).
    std::ofstream output((filename + ".dx").c_str());
    output << "object 1 class array type float rank 1 shape 2 item "
           << n_ele * node_per_grid
           << " data follows" << std::endl;
    output.setf(std::ios::fixed);

    output.precision(20);
    for (size_t ei = 0; ei < n_ele; ei++)
        for (size_t k = 0; k < node_per_grid; k++)
            output << nodes[ei][k][0] << "\t" << nodes[ei][k][1] << '\n';

    output << std::endl;
    output << "object 2 class array type int rank 1 shape 3 item "
           << n_ele * ele_per_grid << " data follows" << std::endl;

    for (size_t ei = 0; ei < n_ele; ei++)
        for (size_t k = 0; k < ele_per_grid; k++)
            output << connections[ei][k][0] << "\t" << connections[ei][k][1] << "\t" << connections[ei][k][2] << "\t" << '\n';
    output << "attribute \"element type\" string \"triangles\"" << std::endl;
    output << "attribute \"ref\" string \"positions\"" << std::endl;
    output << std::endl;
    output << "object 3 class array type float rank 0 item "
           << n_ele * node_per_grid
           << " data follows" << std::endl;
    for (size_t ei = 0; ei < n_ele; ei++)
        for (size_t k = 0; k < node_per_grid; k++)
            output << values[ei][k] << '\n';

    output << "attribute \"dep\" string \"positions\"" << std::endl;
    output << std::endl;
    output << "object \"FEMFunction-2d\" class field" << std::endl;
    output << "component \"positions\" value 1" << std::endl;
    output << "component \"connections\" value 2" << std::endl;
    output << "component \"data\" value 3" << std::endl;
    output << "end" << std::endl;

    output.close();
}

/// Mesh family: structured triangle mesh on the rectangle spanned by the
/// corners lbc and ruc, with nx x ny cells of two triangles each. It is treated
/// as a special case of the irregular mesh.
REGULAR_MESH_TEMPLATE
class FEM2D_RegularMesh : virtual public FEM2D_IrregularMesh<Solution_Type, Mesh_Type, Element_Type, Matrix_Type>
{
public:
    /// Default: corners (0,0) and (1,1) with 2 x 2 cells.
    void Initial() override;
    /// Sets corners and cell counts of the mesh.
    void Initial(const Point<2, double> &_lbc, const Point<2, double> &_ruc, int _nx, int _ny)
    {
        Initial_Regularmesh(_lbc, _ruc, _nx, _ny);
    }
    void Initial_Regularmesh(const Point<2, double> &_lbc, const Point<2, double> &_ruc, int _nx, int _ny);
    /// Writes the Matlab and OpenDx plots; rec is the sampling multiple per cell edge.
    void Output(const std::string &filename, unsigned int rec) override;
    /// Same as Output(filename, 1).
    void Output(const std::string &filename) override
    {
        Output(filename, 1);
    }
    void Regularmesh2D_Matlab_plot(const std::string &filename, unsigned int multiple);
    void Regularmesh2D_OpenDx_plot(const std::string &filename, unsigned int multiple);
};

REGULAR_MESH_TEMPLATE
void FEM2D_RegularMesh<REGULAR_MESH_TEMPLATE_INSIDE>::Initial()
{
    // Default mesh: corners (0,0) and (1,1), 2 x 2 cells.
    const Point<2, double> lower_left({0, 0});
    const Point<2, double> upper_right({1, 1});
    Initial_Regularmesh(lower_left, upper_right, 2, 2);
}

REGULAR_MESH_TEMPLATE
void FEM2D_RegularMesh<REGULAR_MESH_TEMPLATE_INSIDE>::Initial_Regularmesh(const Point<2, double> &_lbc, const Point<2, double> &_ruc, int _nx, int _ny)
{
    // Forward the corners first, then the cell counts, to the mesh object.
    Mesh_Type &grid = this->mesh;
    grid.set_lbc(_lbc);
    grid.set_ruc(_ruc);
    grid.set_nx(_nx);
    grid.set_ny(_ny);
}

REGULAR_MESH_TEMPLATE
void FEM2D_RegularMesh<REGULAR_MESH_TEMPLATE_INSIDE>::Output(const std::string &filename, unsigned int rec)
{
    // For the regular mesh, rec is the sampling multiple: every cell edge is
    // split into `rec` sub-intervals in both plot files.
    const unsigned int multiple = rec;
    this->Regularmesh2D_Matlab_plot(filename, multiple);
    this->Regularmesh2D_OpenDx_plot(filename, multiple);
}

REGULAR_MESH_TEMPLATE
void FEM2D_RegularMesh<REGULAR_MESH_TEMPLATE_INSIDE>::Regularmesh2D_Matlab_plot(const std::string &filename, unsigned int multiple)
{
    if (multiple == 0)
    {
        std::cerr << "multiple cannot be 0 !!" << std::endl;
        multiple = 1;
    }
    /// Output the solution to file in Matlab format.
    std::ofstream output((filename + ".m").c_str());
    const size_t nx = this->mesh.get_nx(), ny = this->mesh.get_ny();
    const double hx = this->mesh.get_hx(), hy = this->mesh.get_hy();
    const double x0 = this->mesh.get_lbc()[0], y0 = this->mesh.get_lbc()[1];
    const size_t n_dof = this->ref_ele.get_n_dofs();
    // Every cell edge is sampled at `multiple` sub-intervals.
    const size_t n_px = multiple * nx + 1;
    const size_t n_py = multiple * ny + 1;

    output << "x=linspace(" << x0 << "," << this->mesh.get_ruc()[0] << "," << n_px << ");" << std::endl;
    output << "y=linspace(" << y0 << "," << this->mesh.get_ruc()[1] << "," << n_py << ");" << std::endl;
    output << "z=zeros(" << n_py << "," << n_px << ");" << std::endl;

    // Visit every sampling point, so the right-most column and the top row are
    // also written (the former per-element loops skipped them). The element's
    // global dofs are set on this->ref_ele, the same object used for
    // evaluation; the old code set them on a shadowing local element.
    if (nx > 0 && ny > 0)
    {
        std::valarray<size_t> dof;
        size_t current_ei = 0;
        bool have_ele = false;
        for (size_t py = 0; py < n_py; py++)
        {
            // Cell row and sub-index; the top sampling row belongs to the last cell row.
            size_t iy = py / multiple, iiy = py % multiple;
            if (iy == ny)
            {
                iy = ny - 1;
                iiy = multiple;
            }
            for (size_t px = 0; px < n_px; px++)
            {
                size_t ix = px / multiple, iix = px % multiple;
                if (ix == nx)
                {
                    ix = nx - 1;
                    iix = multiple;
                }
                // Cell (ix, iy) consists of grid 2*(ix + iy*nx) (points with
                // iix + iiy <= multiple) and grid 2*(ix + iy*nx) + 1 (the rest).
                const size_t ei = 2 * (ix + iy * nx) + (iix + iiy > multiple ? 1 : 0);
                if (!have_ele || ei != current_ei)
                {
                    dof = this->mesh.get_grid(ei).get_dof();
                    for (size_t k = 0; k < n_dof; k++)
                        this->ref_ele.set_global_dof(k, this->mesh.get_dof(dof[k]));
                    current_ei = ei;
                    have_ele = true;
                }
                Point<2, double> global_pnt({x0 + (ix + 1. * iix / multiple) * hx, y0 + (iy + 1. * iiy / multiple) * hy});
                Point<2, double> local_pnt = this->ref_ele.global2local(global_pnt);
                double u_sol = 0;
                for (size_t i = 0; i < n_dof; i++)
                {
                    double phi_i = this->ref_ele.basis_function(i, local_pnt);
                    u_sol += phi_i * this->solution[dof[i]];
                }
                output << "z(" << py + 1 << "," << px + 1 << ")=" << u_sol << ";" << '\n';
            }
        }
    }
    output << "[x,y]=meshgrid(x,y);" << std::endl;
    output << " T = delaunay(x,y);" << std::endl;
    output << "trisurf(T,x,y,z);" << std::endl;
    output.close();
}

REGULAR_MESH_TEMPLATE
void FEM2D_RegularMesh<REGULAR_MESH_TEMPLATE_INSIDE>::Regularmesh2D_OpenDx_plot(const std::string &filename, unsigned int multiple)
{
    // multiple is a divisor below, so 0 must be rejected. The irregular-mesh
    // Output() forwards rec = 0 through the virtual Output(filename, rec).
    if (multiple == 0)
    {
        std::cerr << "multiple cannot be 0 !!" << std::endl;
        multiple = 1;
    }
    /// Output the solution to file in OpenDx format.
    std::ofstream output((filename + ".dx").c_str());

    // Sample on a temporary regular mesh with `multiple` times as many cells
    // per direction as the computational mesh.
    size_t nx = this->mesh.get_nx(), ny = this->mesh.get_ny();
    nx *= multiple;
    ny *= multiple;
    RegularMesh meshtmp;
    meshtmp.set_lbc(this->mesh.get_lbc());
    meshtmp.set_ruc(this->mesh.get_ruc());
    meshtmp.set_nxny(nx, ny);
    const size_t n_node = meshtmp.get_n_points();
    const size_t n_ele = meshtmp.get_n_grids();
    const size_t n_dof = this->ref_ele.get_n_dofs();

    output << "object 1 class array type float rank 1 shape 2 item "
           << n_node
           << " data follows" << std::endl;
    output.setf(std::ios::fixed);

    output.precision(20);
    for (size_t ni = 0; ni < n_node; ni++)
    {
        Point<2, double> pnt = meshtmp.get_point(ni);
        output << pnt[0] << "\t" << pnt[1] << '\n';
    }

    output << std::endl;
    output << "object 2 class array type int rank 1 shape 3 item " << n_ele << " data follows" << std::endl;

    for (size_t ei = 0; ei < n_ele; ei++)
    {
        const Geometry &the_ele = meshtmp.get_grid(ei);
        output << the_ele.get_vertex(0) << "\t" << the_ele.get_vertex(1) << "\t" << the_ele.get_vertex(2) << '\n';
    }
    output << "attribute \"element type\" string \"triangles\"" << std::endl;
    output << "attribute \"ref\" string \"positions\"" << std::endl;
    output << std::endl;
    output << "object 3 class array type float rank 0 item "
           << n_node
           << " data follows" << std::endl;

    const double hx = meshtmp.get_hx(), hy = meshtmp.get_hy();
    // Lower-left corner. The sample points previously omitted this offset, so
    // they were wrong for any mesh not anchored at (0, 0).
    const double x0 = this->mesh.get_lbc()[0], y0 = this->mesh.get_lbc()[1];
    const size_t cells_x = nx / multiple, cells_y = ny / multiple; // cells of the computational mesh
    bool have_ele = false; // explicit flag instead of a magic sentinel index
    size_t last_ei = 0;
    std::valarray<size_t> dof;
    for (size_t ni = 0; ni < n_node; ni++)
    {
        size_t ix = (ni) % (nx + 1); // index of node
        size_t iy = (ni) / (nx + 1);
        Point<2, double> global_pnt({x0 + ix * hx, y0 + iy * hy});
        int is_upper = (ix % multiple + iy % multiple) > multiple; // whether the node lies in the upper triangle
        ix /= multiple;                                            // index of the computational cell
        iy /= multiple;
        if (ix == cells_x) // right boundary: belongs to the last cell column
        {
            ix--;
            is_upper = 1;
        }
        if (iy == cells_y) // upper boundary: belongs to the last cell row
        {
            iy--;
            is_upper = 1;
        }

        const size_t ei = (2 * ix) + iy * (2 * cells_x) + is_upper;
        if (!have_ele || ei != last_ei)
        {
            dof = this->mesh.get_grid(ei).get_dof();
            for (size_t k = 0; k < n_dof; k++)
                this->ref_ele.set_global_dof(k, this->mesh.get_dof(dof[k]));
            last_ei = ei;
            have_ele = true;
        }
        Point<2, double> local_pnt = this->ref_ele.global2local(global_pnt);
        double u_sol = 0;
        for (size_t i = 0; i < n_dof; i++)
        {
            double phi_i = this->ref_ele.basis_function(i, local_pnt);
            u_sol += phi_i * this->solution[dof[i]];
        }
        output << u_sol << '\n';
    }

    output << "attribute \"dep\" string \"positions\"" << std::endl;
    output << std::endl;
    output << "object \"FEMFunction-2d\" class field" << std::endl;
    output << "component \"positions\" value 1" << std::endl;
    output << "component \"connections\" value 2" << std::endl;
    output << "component \"data\" value 3" << std::endl;
    output << "end" << std::endl;
    output.close();
}

#undef FEM_TEMPLATE
#undef FEM_TEMPLATE_INSIDE
#undef IRREGULAR_MESH_TEMPLATE
#undef IRREGULAR_MESH_TEMPLATE_INSIDE
#undef REGULAR_MESH_TEMPLATE
#undef REGULAR_MESH_TEMPLATE_INSIDE

#else
// DO NOTHING.
#endif